;; gamma functions
(in-package :clml.nonparametric.statistics)

(defparameter +gamma-iter+ 9.5d0
  "Recurrence target for LOGGAMMA, DIGAMMA and TRIGAMMA: while the argument
is below this value it is shifted upward by 1 before an asymptotic series is
evaluated at the shifted point.")
(defparameter +digamma-small+ 1d-6
  "Threshold used by DIGAMMA: small positive arguments below it take the
small-argument approximation -gamma - 1/z instead of the recurrence.")

;;; Bernoulli numbers B_n used as coefficients of the asymptotic series in
;;; LOGGAMMA, DIGAMMA and TRIGAMMA.  They are defined inside EVAL-WHEN because
;;; those functions splice them in with #. at read time.  Every entry is a
;;; double-float.
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; NOTE: despite its name this constant holds the NEGATED Euler-Mascheroni
  ;; constant (-gamma); DIGAMMA computes -gamma - 1/z as
  ;; (- EULER-MASCHERONI (/ z)) and depends on this sign.
  (defconstant EULER-MASCHERONI -0.5772156649015328606065121d0)
  (defconstant B0 1d0)
  (defconstant B1 (- (/ 1d0 2d0)))
  (defconstant B2 (/ 1d0 6d0))
  (defconstant B4 (- (/ 1d0 30d0)))
  (defconstant B6 (/ 1d0 42d0))
  (defconstant B8 (- (/ 1d0 30d0)))
  (defconstant B10 (/ 5d0 66d0))
  (defconstant B12 (- (/ 691d0 2730d0)))
  (defconstant B14 (/ 7d0 6d0))
  (defconstant B16 (- (/ 3617d0 510d0)))
  (defconstant B18 (/ 43867d0 798d0))
  ;; Double-float literals so B20 has the same type as the other entries
  ;; (integer operands would make it an exact rational).
  (defconstant B20 (- (/ 174611d0 330d0)))
  ;; more bernoulli numbers...
  )

(defun gamma-function (x)
  "Gamma function of the real number X.

X is first converted with DFLOAT (presumably to a double-float -- defined
elsewhere).  For X >= 0 the result is exp(loggamma x).  For X < 0 the
reflection formula Gamma(x) = pi / (sin(pi x) * Gamma(1 - x)) is used, so
LOGGAMMA is only ever evaluated at a positive argument.  The poles at
0, -1, -2, ... are not special-cased."
  (setf x (dfloat x))
  (if (< x 0)
      ;; Reflection: Gamma(x) * Gamma(1-x) = pi / sin(pi x).  Using
      ;; Gamma(1-x) (not Gamma(x)) in the denominator is what makes this
      ;; return Gamma(x); 1 - x > 1 keeps LOGGAMMA on positive input.
      (/ pi (* (sin (* pi x)) (exp (loggamma (- 1d0 x)))))
      (exp (loggamma x))))

(defun loggamma (x)
  "Natural logarithm of the gamma function for a double-float X.

While X is below +GAMMA-ITER+ it is raised by 1, and the skipped arguments
are multiplied into PROD (Gamma(x) = Gamma(x+k) / (x (x+1) ... (x+k-1))).
Stirling's series
  log Gamma(x) ~ (x - 1/2) log x - x + log(2 pi)/2
                 + sum_{k=1..10} B_{2k} / (2k (2k-1) x^(2k-1))
is then evaluated at the shifted X, and log PROD is subtracted.
Intended for X > 0: for X <= 0 PROD can be zero or negative, and its LOG is
not a real double-float."
  (declare (optimize (speed 3) (debug 0) (safety 0))
           (type double-float x))
  (let ((prod 1d0))
    (declare (type double-float prod))
    (do ()
        ((not (< x +gamma-iter+)))
      (setf prod (* prod x))
      (incf x))
    (let* ((w (/ (* x x)))
           ;; Horner accumulator for the correction polynomial in W = 1/x^2;
           ;; seeded with the two highest-order (B20, B18) terms.
           (acc (+ (* #.(/ B20 (* 20 19)) w) #.(/ B18 (* 18 17)))))
      (declare (type double-float w acc))
      ;; Fold in the remaining coefficients B_{2k}/(2k(2k-1)), k = 8 .. 1.
      (dolist (coef '#.(list (/ B16 (* 16 15)) (/ B14 (* 14 13))
                             (/ B12 (* 12 11)) (/ B10 (* 10 9))
                             (/ B8 (* 8 7)) (/ B6 (* 6 5))
                             (/ B4 (* 4 3)) (/ B2 (* 2 1))))
        (declare (type double-float coef))
        (setf acc (+ (* acc w) coef)))
      (+ (/ acc x)
         #.(* 0.5d0 (log (* 2d0 pi)))
         (- x)
         (* (- x 0.5d0) (the double-float (log x)))
         (- (the double-float (log prod)))))))

(defun digamma (z)
  "Digamma function psi(z) = d/dz log Gamma(z) for a double-float Z.

Z < 0:
  reflection formula psi(z) = psi(1 - z) - pi / tan(pi z).
0 <= Z < +DIGAMMA-SMALL+:
  leading terms psi(z) ~ -gamma - 1/z.
Otherwise:
  the recurrence psi(z) = psi(z + 1) - 1/z raises Z to at least
  +GAMMA-ITER+, then the asymptotic series
    psi(z) ~ log z - 1/(2z) - sum_{k=1..10} B_{2k} / (2k z^(2k))
  is evaluated at the shifted Z.
The poles at 0, -1, -2, ... are not special-cased."
  (declare (optimize (speed 3) (debug 0) (safety 0))
           (type double-float z))
  (cond
    ((< z 0d0)
     ;; 1 - z > 1, so the recursive call never re-enters this branch.
     (- (digamma (- 1d0 z))
        (/ #.(float pi 1d0) (tan (* #.(float pi 1d0) z)))))
    ((< z +digamma-small+)
     ;; EULER-MASCHERONI holds -gamma (negative sign), so this is -gamma - 1/z.
     (- EULER-MASCHERONI (/ z)))
    (t
     (let ((psi 0d0))
       (declare (type double-float psi))
       (loop while (< z +gamma-iter+) do
             (decf psi (/ z))
             (incf z))
       (let* ((invz (/ z))
              (w (* invz invz)))
         (declare (type double-float invz w))
         (+ psi
            (the double-float (log z))
            (* -0.5d0 invz)
            ;; Horner form of -sum_{k=1..10} B_{2k}/(2k) w^k.  Every
            ;; coefficient enters with a minus sign, so the innermost B20
            ;; coefficient is negated explicitly.
            (* (- (* (- (* (- (* (- (* (- (* (- (* (- (* (- (* (- (* #.(- (/ B20 20)) w)
                                                                    #.(/ B18 18)) w)
                                                              #.(/ B16 16)) w)
                                                        #.(/ B14 14)) w)
                                                  #.(/ B12 12)) w)
                                            #.(/ B10 10)) w)
                                      #.(/ B8 8)) w)
                                #.(/ B6 6)) w)
                          #.(/ B4 4)) w)
                    #.(/ B2 2)) w)))))))

(defun trigamma (x)
  "Trigamma function psi'(x) (second derivative of log Gamma) for a
double-float X.

While X is below +GAMMA-ITER+ the recurrence psi'(x) = psi'(x+1) + 1/x^2
raises X by 1, and the skipped 1/x^2 terms are collected in TAIL.  At the
shifted X the asymptotic series
  psi'(x) ~ 1/x + 1/(2x^2) + sum_{k=1..10} B_{2k} / x^(2k+1)
is summed and TAIL added."
  (declare (optimize (speed 3) (debug 0) (safety 0))
           (type double-float x))
  (let ((tail 0d0))
    (declare (type double-float tail))
    (do ()
        ((not (< x +gamma-iter+)))
      (incf tail (/ (* x x)))
      (incf x))
    (let* ((w (/ (* x x)))
           ;; Horner accumulator for B2 + B4 w + ... + B20 w^9 (W = 1/x^2),
           ;; seeded with the two highest-order terms.
           (acc (+ (* B20 w) B18)))
      (declare (type double-float w acc))
      (dolist (b '#.(list B16 B14 B12 B10 B8 B6 B4 B2))
        (declare (type double-float b))
        (setf acc (+ (* acc w) b)))
      (+ (/ (* acc w) x)
         (/ x)
         (* w 0.5)
         tail))))


(defun beta-function (x y)
  "Beta function B(x, y) = Gamma(x) Gamma(y) / Gamma(x + y) for double-floats.

For X > 0 and Y > 0 the value is computed in log space as
exp(loggamma x + loggamma y - loggamma(x + y)).  This stays finite where the
individual gamma values overflow a double-float (e.g. X = Y = 100, where
Gamma(100)^2 and Gamma(200) both exceed the double range).  For any other
arguments it falls back to the direct ratio of GAMMA-FUNCTION values."
  (declare (optimize (speed 3) (debug 0) (safety 0))
           (type double-float x y))
  (if (and (> x 0d0) (> y 0d0))
      ;; Log-space evaluation: LOGGAMMA is only called on positive arguments.
      (exp (- (+ (loggamma x) (loggamma y))
              (loggamma (+ x y))))
      (/ (* (gamma-function x) (gamma-function y))
         (gamma-function (+ x y)))))
